import os.path as osp

import argparse
import torch
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv, GAE, VGAE
from torch_geometric.utils import train_test_split_edges
from torch_geometric.data import GraphSAINTRandomWalkSampler, NeighborSampler
from torch_geometric.nn import SAGEConv
from torch_geometric.utils import subgraph
from ogb.nodeproppred import PygNodePropPredDataset
from torch_geometric.data import DataLoader
from sklearn.linear_model import LogisticRegression
from tqdm import tqdm
import numpy as np
import ipdb

class GCNEncoder(torch.nn.Module):
    """Two-layer graph encoder used as the deterministic ``GAE`` encoder.

    Despite the class name, both layers are ``SAGEConv``.  The attribute
    names ``conv1`` and ``conv2`` are accessed directly by the layer-wise
    inference routine in this file, so they must not be renamed.

    Args:
        in_channels: Size of the input node features.
        out_channels: Size of the hidden and output embeddings.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Layers are created in the order conv1, conv2 so that seeded
        # parameter initialisation stays the same.
        self.conv1 = SAGEConv(in_channels, out_channels)
        self.conv2 = SAGEConv(out_channels, out_channels)

    def forward(self, x, edge_index):
        """Return ``conv2(relu(conv1(x)))`` evaluated on ``edge_index``."""
        hidden = torch.relu(self.conv1(x, edge_index))
        return self.conv2(hidden, edge_index)


class VariationalGCNEncoder(torch.nn.Module):
    """Two-layer graph encoder producing the ``(mu, logstd)`` pair for ``VGAE``.

    A shared ``SAGEConv`` + ReLU layer feeds two separate ``SAGEConv`` heads.
    The attribute names ``conv1``, ``conv_mu`` and ``conv_logstd`` are used
    directly by the layer-wise inference routine in this file.

    Args:
        in_channels: Size of the input node features.
        out_channels: Size of the hidden layer and of both output heads.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Creation order (conv1, conv_mu, conv_logstd) is kept so seeded
        # parameter initialisation stays the same.
        self.conv1 = SAGEConv(in_channels, out_channels)
        self.conv_mu = SAGEConv(out_channels, out_channels)
        self.conv_logstd = SAGEConv(out_channels, out_channels)

    def forward(self, x, edge_index):
        """Return the tuple ``(mu, logstd)`` computed from the shared layer."""
        hidden = torch.relu(self.conv1(x, edge_index))
        mu = self.conv_mu(hidden, edge_index)
        logstd = self.conv_logstd(hidden, edge_index)
        return mu, logstd


class LinearEncoder(torch.nn.Module):
    """Single-layer encoder for ``GAE``: one ``SAGEConv`` with no activation.

    Args:
        in_channels: Size of the input node features.
        out_channels: Size of the output embeddings.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = SAGEConv(in_channels, out_channels)

    def forward(self, x, edge_index):
        """Return the output of the single convolution, unchanged."""
        embedding = self.conv(x, edge_index)
        return embedding


class VariationalLinearEncoder(torch.nn.Module):
    """Single-layer encoder for ``VGAE`` with independent mu / logstd heads.

    Each head is one ``SAGEConv`` applied directly to the input features;
    there is no shared hidden layer and no activation.

    Args:
        in_channels: Size of the input node features.
        out_channels: Size of both output heads.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Creation order (conv_mu, conv_logstd) is kept so seeded parameter
        # initialisation stays the same.
        self.conv_mu = SAGEConv(in_channels, out_channels)
        self.conv_logstd = SAGEConv(in_channels, out_channels)

    def forward(self, x, edge_index):
        """Return the tuple ``(mu, logstd)``."""
        mu = self.conv_mu(x, edge_index)
        logstd = self.conv_logstd(x, edge_index)
        return mu, logstd
    

def _propagate_layer(layer_fn, x_all, subgraph_loader, device, pbar):
    """Apply one message-passing stage to all nodes, one sampled batch at a time.

    Args:
        layer_fn: Callable taking ``((x, x_target), edge_index)`` and returning
            the output rows for the batch.
        x_all: Tensor holding the current representation of every node; it is
            indexed with the ``n_id`` yielded by the loader.
        subgraph_loader: Iterable yielding ``(batch_size, n_id, adj)`` triples.
        device: Device on which the stage is evaluated.
        pbar: Progress bar advanced by ``batch_size`` after every batch.

    Returns:
        The per-batch outputs, moved to CPU and concatenated along dim 0.
    """
    outputs = []
    for batch_size, n_id, adj in subgraph_loader:
        edge_index, _, size = adj.to(device)
        x = x_all[n_id].to(device)
        # The leading size[1] rows are handed to the layer as the target-node
        # features of the bipartite (source, target) input.
        x_target = x[:size[1]]
        outputs.append(layer_fn((x, x_target), edge_index).cpu())
        pbar.update(batch_size)
    return torch.cat(outputs, dim=0)


def inference_GAE(model, x_all, subgraph_loader, device, args=None):
    """Compute embeddings for all nodes layer by layer over ``subgraph_loader``.

    Three encoder layouts are supported, selected by ``args``:

    * ``args.linear`` truthy: a single pass through ``model.encode``.
    * otherwise, ``args.variational`` falsy: ``encoder.conv1`` + ReLU followed
      by ``encoder.conv2``.
    * otherwise: ``encoder.conv1`` + ReLU followed by the ``conv_mu`` /
      ``conv_logstd`` heads, clamping and ``model.reparametrize``.

    This function does not disable gradient tracking itself; callers that only
    need embeddings should invoke it under ``torch.no_grad()``.

    Args:
        model: The ``GAE``/``VGAE`` model whose encoder is evaluated.
        x_all: Input feature tensor with one row per node.
        subgraph_loader: Iterable yielding ``(batch_size, n_id, adj)`` triples.
        device: Device on which the layers are evaluated.
        args: Object with ``linear`` and ``variational`` attributes.  When it
            is ``None`` or an attribute is missing, that flag is treated as
            false (previously this raised ``AttributeError``).

    Returns:
        A CPU tensor with the concatenated per-batch outputs of the last stage.
    """
    # Upper bound applied to the log-std head before reparametrisation.
    MAX_LOGSTD = 10

    # Coerce to real booleans so every input selects exactly one branch;
    # comparing with ``== True`` let other values fall through and return None.
    linear = bool(getattr(args, 'linear', False))
    variational = bool(getattr(args, 'variational', False))

    def first_layer(x_pair, edge_index):
        return model.encoder.conv1(x_pair, edge_index).relu()

    def variational_head(x_pair, edge_index):
        # Stored on the model exactly as before so that later calls reading
        # these attributes (e.g. a KL term) see the last evaluated batch.
        model.__mu__ = model.encoder.conv_mu(x_pair, edge_index)
        model.__logstd__ = model.encoder.conv_logstd(x_pair, edge_index)
        model.__logstd__ = model.__logstd__.clamp(max=MAX_LOGSTD)
        return model.reparametrize(model.__mu__, model.__logstd__)

    if linear:
        stages = [model.encode]
    elif variational:
        stages = [first_layer, variational_head]
    else:
        stages = [first_layer, model.encoder.conv2]

    # One full sweep over the loader per stage; the context manager closes the
    # bar even if a batch raises.
    with tqdm(total=x_all.size(0) * len(stages)) as pbar:
        pbar.set_description('Evaluating')
        for stage in stages:
            x_all = _propagate_layer(stage, x_all, subgraph_loader, device, pbar)
    return x_all
    
@torch.no_grad()
def test_GAE(model, data, subgraph_loader, split_idx, device, args, SAVEPATH=None):
    """Embed all nodes, optionally save the embeddings, and score a linear probe.

    Args:
        model: Trained ``GAE``/``VGAE`` model; it is switched to eval mode.
        data: Graph object providing ``x`` (node features) and ``y`` (labels).
        subgraph_loader: Loader forwarded to ``inference_GAE``.
        split_idx: Mapping with ``'train'`` and ``'test'`` node index entries.
        device: Device used for inference.
        args: Options forwarded to ``inference_GAE``; ``args.Emb_only`` is read
            when ``SAVEPATH`` is given.
        SAVEPATH: Where to ``torch.save`` the embeddings, or ``None`` to skip
            saving.

    Returns:
        The logistic-regression accuracy on the test split, or ``None`` when
        embeddings were saved and ``args.Emb_only`` is set.
    """
    # Local import: the file only binds ``os.path`` (as ``osp``), not ``os``.
    import os

    model.eval()
    z = inference_GAE(model, data.x, subgraph_loader, device, args)
    if SAVEPATH is not None:
        # Create the target directory first so a missing folder does not throw
        # away a finished training run.  File-like objects are passed through.
        if isinstance(SAVEPATH, (str, os.PathLike)):
            out_dir = os.path.dirname(os.fspath(SAVEPATH))
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)
        torch.save(z, SAVEPATH)
        if args.Emb_only:
            print('Skip Logistic regression!')
            return None

    train_idx = split_idx['train']
    test_idx = split_idx['test']

    solver = 'lbfgs'
    multi_class = 'auto'
    clf = LogisticRegression(solver=solver, multi_class=multi_class, max_iter=500).fit(
        z[train_idx].detach().cpu().numpy(),
        data.y[train_idx].view(-1).detach().cpu().numpy())

    # Labels are flattened here as well, matching how the training labels are
    # passed to ``fit``.
    return clf.score(z[test_idx].detach().cpu().numpy(),
                     data.y[test_idx].view(-1).detach().cpu().numpy())
    
    
def train(model, optimizer, loader, device, args):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    For every batch the reconstruction loss on the batch's own edges is
    minimised.  When ``args.variational`` is truthy, the model's KL term is
    added, scaled by one over the number of nodes in the batch.

    Args:
        model: Model exposing ``encode``, ``recon_loss`` and ``kl_loss``.
        optimizer: Optimizer stepping the model parameters.
        loader: Sized iterable of batches with ``x`` and ``edge_index``.
        device: Device the batch tensors are moved to.
        args: Object with a ``variational`` attribute.

    Returns:
        Sum of the per-batch loss values divided by ``len(loader)``.
    """
    model.train()
    batch_losses = []
    for batch in loader:
        features = batch.x.to(device)
        pos_edges = batch.edge_index.to(device)

        optimizer.zero_grad()
        latent = model.encode(features, pos_edges)
        loss = model.recon_loss(latent, pos_edges)
        if args.variational:
            kl_weight = 1 / batch.x.shape[0]
            loss = loss + kl_weight * model.kl_loss()

        loss.backward()
        optimizer.step()
        batch_losses.append(loss.item())

    return sum(batch_losses) / len(loader)

def main():
    """Train a (V)GAE on ogbn-arxiv, save node embeddings and report accuracy.

    Command-line options:
        --variational: 1 selects ``VGAE``; any other value selects ``GAE``.
        --linear: 1 selects the single-layer encoders; any other value selects
            the two-layer encoders.
        --save_name: Tag used in the embedding file name.
        --input_feature_path: ``.npy`` file whose contents replace the
            dataset's node features; the literal string ``'None'`` keeps them.
        --epochs: Number of training epochs.
        --device: CUDA device index (ignored when CUDA is unavailable).
        --Emb_only: Only save embeddings; skip the logistic-regression probe.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--variational', type=int, default=0)
    parser.add_argument('--linear', type=int, default=0)
    parser.add_argument('--save_name', type=str, default='OGB_feature')
    parser.add_argument('--input_feature_path', type=str, default='None')
    parser.add_argument('--epochs', type=int, default=30)
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--Emb_only', action='store_true')
    args = parser.parse_args()

    # The flags arrive as integers; only the value 1 enables them.
    args.variational = args.variational == 1
    args.linear = args.linear == 1

    print(args)

    # Fall back to CPU instead of crashing on hosts without CUDA.
    if torch.cuda.is_available():
        device = torch.device('cuda:' + str(args.device))
    else:
        print('CUDA is not available; running on CPU.')
        device = torch.device('cpu')

    SAVEPATH = './GAE_embedding/Lin_{}-Var_{}-input_{}.pt'.format(args.linear,
                         args.variational,
                         args.save_name)

    dataset = PygNodePropPredDataset(name="ogbn-arxiv", root="../../dataset")
    split_idx = dataset.get_idx_split()
    data = dataset[0]

    # Replace node features here
    if args.input_feature_path != 'None':
        # Cast to float32: the encoder weights use the default float32 dtype,
        # and arrays saved by NumPy are frequently float64.
        data.x = torch.tensor(np.load(args.input_feature_path), dtype=torch.float)
        print("Pretrained node features loaded! Path: {}".format(args.input_feature_path))

    out_channels = 768
    num_features = data.x.shape[1]

    if args.variational:
        encoder_cls = VariationalLinearEncoder if args.linear else VariationalGCNEncoder
        model = VGAE(encoder_cls(num_features, out_channels))
    else:
        encoder_cls = LinearEncoder if args.linear else GCNEncoder
        model = GAE(encoder_cls(num_features, out_channels))

    # Move the model to its device before building the optimizer, as the
    # PyTorch optimizer documentation recommends.
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

    # Originally batch_size = 20000
    # In linear case, batch_size = 10000
    loader = GraphSAINTRandomWalkSampler(data,
                                         batch_size=10000,
                                         walk_length=3,
                                         num_steps=30,
                                         sample_coverage=0,
                                         save_dir=dataset.processed_dir)

    # sizes=[-1]: one hop per pass, used for layer-wise inference.
    subgraph_loader = NeighborSampler(data.edge_index, sizes=[-1],
                                      batch_size=4096, shuffle=False,
                                      num_workers=12)

    for epoch in range(1, args.epochs + 1):
        loss = train(model, optimizer, loader, device, args)
        print('Epoch: {:03d}, Loss: {:.4f}'.format(epoch, loss))

    acc = test_GAE(model, data, subgraph_loader, split_idx, device, args, SAVEPATH)
    # test_GAE returns None when only embeddings were requested.
    if not args.Emb_only:
        print('Accuracy: {:.2f}'.format(acc*100))

# Run the experiment only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()